{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "from fastai.io import *\n",
    "from fastai.conv_learner import *\n",
    "\n",
    "from fastai.column_data import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We're going to download the collected works of Nietzsche to use as our data for this class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH = 'data/nietzsche/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "nietzsche.txt: 606KB [00:00, 1.86MB/s]                            "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "corpus length: 600893\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "get_data('https://s3.amazonaws.com/text-datasets/nietzsche.txt', f'{PATH}nietzsche.txt')\n",
    "text = open(f'{PATH}nietzsche.txt').read()\n",
    "print('corpus length:', len(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not ground\\nfor suspecting that all philosophers, in so far as they have been\\ndogmatists, have failed to understand women--that the terrible\\nseriousness and clumsy importunity with which they have usually paid\\ntheir addresses to Truth, have been unskilled and unseemly methods for\\nwinning a woman? Certainly she has never allowed herself '"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text[:400]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total chars: 85\n"
     ]
    }
   ],
   "source": [
    "chars = sorted(list(set(text)))\n",
    "vocab_size = len(chars) + 1\n",
    "print('total chars:', vocab_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes it's useful to have a zero value in the dataset, e.g. for padding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n !\"\\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chars.insert(0, '\\0')\n",
    "''.join(chars[1:-6])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Map from chars to indices and back again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "char_indices = {c: i for i, c in enumerate(chars)}\n",
    "indices_char = {i: c for i, c in enumerate(chars)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*idx* will be the data we use from now on - it simply converts all the characters to their index (based on the mapping above)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = [char_indices[c] for c in text]\n",
    "idx[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not gro'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "''.join(indices_char[i] for i in idx[:70])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Three char model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Create inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Create a list of every 4th character, starting at the 0th, 1st, 2nd, then 3rd characters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "cs=3\n",
    "c1_dat = [idx[i]     for i in range(0, len(idx) - cs, cs)]\n",
    "c2_dat = [idx[i + 1] for i in range(0, len(idx) - cs, cs)]\n",
    "c3_dat = [idx[i + 2] for i in range(0, len(idx) - cs, cs)]\n",
    "c4_dat = [idx[i + 3] for i in range(0, len(idx) - cs, cs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Our inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "x1 = np.stack(c1_dat)\n",
    "x2 = np.stack(c2_dat)\n",
    "x3 = np.stack(c3_dat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Our output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "y = np.stack(c4_dat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "The first 4 inputs and outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([40, 30, 29,  1]), array([42, 25,  1, 43]), array([29, 27,  1, 45]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[:4], x2[:4], x3[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([30, 29,  1, 40])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200297,), (200297,))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape, y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create and train model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pick a size for our hidden state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_hidden = 256"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of latent factors to create (i.e. the size of the embedding matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_fac = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Char3Model(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "\n",
    "        # The 'green arrow' from our diagram - the layer operation from input to hidden\n",
    "        self.l_in = nn.Linear(n_fac, n_hidden)\n",
    "\n",
    "        # The 'orange arrow' from our diagram - the layer operation from hidden to hidden\n",
    "        self.l_hidden = nn.Linear(n_hidden, n_hidden)\n",
    "        \n",
    "        # The 'blue arrow' from our diagram - the layer operation from hidden to output\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, c1, c2, c3):\n",
    "        in1 = F.relu(self.l_in(self.e(c1)))\n",
    "        in2 = F.relu(self.l_in(self.e(c2)))\n",
    "        in3 = F.relu(self.l_in(self.e(c3)))\n",
    "        \n",
    "        h = V(torch.zeros(in1.size()).cuda())\n",
    "        h = F.tanh(self.l_hidden(h + in1))\n",
    "        h = F.tanh(self.l_hidden(h + in2))\n",
    "        h = F.tanh(self.l_hidden(h + in3))\n",
    "        \n",
    "        return F.log_softmax(self.l_out(h))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ColumnarModelData.from_arrays('.', [-1], np.stack([x1,x2,x3], axis=1), y, bs=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = Char3Model(vocab_size, n_fac).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xs, yt = next(it)\n",
    "t = m(*V(xs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = optim.Adam(m.parameters(), 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0016d1f98f524ac69331c573514ab840",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      2.096911   1.226288  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.22629])]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e6c99d85a3d4aee8b737eda3347ad32",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.84777    0.387795  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([0.3878])]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = T(np.array([char_indices[c] for c in inp]))\n",
    "    p = m(*VV(idxs))\n",
    "    i = np.argmax(to_np(p))\n",
    "    return chars[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "hidden": true,
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'T'"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('y. ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "hidden": true,
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'a'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('ppl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next(' th')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' '"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('and')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our first RNN!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Create inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "This is the size of our unrolled RNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "cs = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "For each of 0 through 7, create a list of every 8th character with that starting point. These will be the 8 inputs to our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "c_in_dat = [[idx[i + j] for i in range(cs)] for j in range(len(idx) - cs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Then create a list of the next character in each of these series. This will be the labels for our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "c_out_dat = [idx[j + cs] for j in range(len(idx) - cs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "xs = np.stack(c_in_dat, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(600885, 8)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "y = np.stack(c_out_dat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "So each column below is one series of 8 characters from the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[40, 42, 29, 30, 25, 27, 29,  1],\n",
       "       [42, 29, 30, 25, 27, 29,  1,  1],\n",
       "       [29, 30, 25, 27, 29,  1,  1,  1],\n",
       "       [30, 25, 27, 29,  1,  1,  1, 43],\n",
       "       [25, 27, 29,  1,  1,  1, 43, 45],\n",
       "       [27, 29,  1,  1,  1, 43, 45, 40],\n",
       "       [29,  1,  1,  1, 43, 45, 40, 40],\n",
       "       [ 1,  1,  1, 43, 45, 40, 40, 39]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs[:cs, :cs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "...and this is the next character after each sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 1,  1, 43, 45, 40, 40, 39, 43])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[:cs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_idx = get_cv_idxs(len(idx) - cs - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ColumnarModelData.from_arrays('.', val_idx, xs, y, bs=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharLoopModel(nn.Module):\n",
    "    # This is an RNN!\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.l_in = nn.Linear(n_fac, n_hidden)\n",
    "        self.l_hidden = nn.Linear(n_hidden, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(bs, n_hidden).cuda())\n",
    "        for c in cs:\n",
    "            inp = F.relu(self.l_in(self.e(c)))\n",
    "            h = F.tanh(self.l_hidden(h + inp))\n",
    "        \n",
    "        return F.log_softmax(self.l_out(h), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharLoopModel(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd211011246c44f1ad5ad09226fde1a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      2.022557   2.007089  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([2.00709])]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0159c9f281c24e5aa61ee6b508c50b28",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.729759   1.729243  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.72924])]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharLoopConcatModel(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.l_in = nn.Linear(n_fac + n_hidden, n_hidden)\n",
    "        self.l_hidden = nn.Linear(n_hidden, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(bs, n_hidden).cuda())\n",
    "        for c in cs:\n",
    "            inp = torch.cat((h, self.e(c)), 1)\n",
    "            inp = F.relu(self.l_in(inp))\n",
    "            h = F.tanh(self.l_hidden(inp))\n",
    "        \n",
    "        return F.log_softmax(self.l_out(h), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharLoopConcatModel(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xs, yt = next(it)\n",
    "t = m(*V(xs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d86391d2a0142e597757eeb507aed2e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.873534   1.854928  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.85493])]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d33b6c91c481420ab3cbf14f38935cb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.75918    1.758367  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.75837])]"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = T(np.array([char_indices[c] for c in inp]))\n",
    "    p = m(*VV(idxs))\n",
    "    i = np.argmax(to_np(p))\n",
    "    return chars[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('for thos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'t'"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('part of ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'n'"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('queens a')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## RNN with pytorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharRnn(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNN(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(1, bs, n_hidden))\n",
    "        inp = self.e(torch.stack(cs))\n",
    "        outp,h = self.rnn(inp, h)\n",
    "        \n",
    "        return F.log_softmax(self.l_out(outp[-1]), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharRnn(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xs, yt = next(it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 512, 42])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = m.e(V(torch.stack(xs)))\n",
    "t.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([8, 512, 256]), torch.Size([1, 512, 256]))"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ht = V(torch.zeros(1, 512, n_hidden))\n",
    "outp, hn = m.rnn(t, ht)\n",
    "outp.size(), hn.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([512, 85])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = m(*V(xs))\n",
    "t.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d246d3e8fc5422b8f34dc9f90ff8655",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.874941   1.848242  \n",
      "    1      1.686629   1.681533                              \n",
      "    2      1.59148    1.597262                              \n",
      "    3      1.537338   1.551838                              \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.55184])]"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f871c20555ee4ffa8c50b8526987a20a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.473922   1.514578  \n",
      "    1      1.46794    1.508425                              \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.50842])]"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 2, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = T(np.array([char_indices[c] for c in inp]))\n",
    "    p = m(*VV(idxs))\n",
    "    i = np.argmax(to_np(p))\n",
    "    return chars[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'e'"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('for thos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "def get_next_n(inp, n):\n",
    "    res = inp\n",
    "    for i in range(n):\n",
    "        c = get_next(inp)\n",
    "        res += c\n",
    "        inp = inp[1:] + c\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'for those and the same to the same to the same t'"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next_n('for thos', 40)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Multi-output model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Let's take non-overlapping sets of characters this time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "c_in_dat = [[idx[i + j] for i in range(cs)] for j in range(0, len(idx) - cs - 1, cs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "Then create the exact same thing, offset by 1, as our labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "c_out_dat = [[idx[i + j] for i in range(cs)] for j in range(1, len(idx) - cs, cs)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(75111, 8)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs = np.stack(c_in_dat)\n",
    "xs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(75111, 8)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ys = np.stack(c_out_dat)\n",
    "ys.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[40, 42, 29, 30, 25, 27, 29,  1],\n",
       "       [ 1,  1, 43, 45, 40, 40, 39, 43],\n",
       "       [33, 38, 31,  2, 73, 61, 54, 73],\n",
       "       [ 2, 44, 71, 74, 73, 61,  2, 62],\n",
       "       [72,  2, 54,  2, 76, 68, 66, 54],\n",
       "       [67,  9,  9, 76, 61, 54, 73,  2],\n",
       "       [73, 61, 58, 67, 24,  2, 33, 72],\n",
       "       [ 2, 73, 61, 58, 71, 58,  2, 67]])"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xs[:cs, :cs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[42, 29, 30, 25, 27, 29,  1,  1],\n",
       "       [ 1, 43, 45, 40, 40, 39, 43, 33],\n",
       "       [38, 31,  2, 73, 61, 54, 73,  2],\n",
       "       [44, 71, 74, 73, 61,  2, 62, 72],\n",
       "       [ 2, 54,  2, 76, 68, 66, 54, 67],\n",
       "       [ 9,  9, 76, 61, 54, 73,  2, 73],\n",
       "       [61, 58, 67, 24,  2, 33, 72,  2],\n",
       "       [73, 61, 58, 71, 58,  2, 67, 68]])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ys[:cs, :cs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_idx = get_cv_idxs(len(xs) - cs - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "md = ColumnarModelData.from_arrays('.', val_idx, xs, ys, bs=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqRnn(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac):\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNN(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        \n",
    "    def forward(self, *cs):\n",
    "        bs = cs[0].size(0)\n",
    "        h = V(torch.zeros(1, bs, n_hidden))\n",
    "        inp = self.e(torch.stack(cs))\n",
    "        outp,h = self.rnn(inp, h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqRnn(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(md.trn_dl)\n",
    "*xst, yt = next(it)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nll_loss_seq(inp, targ):\n",
    "    sl, bs, nh = inp.size()\n",
    "    targ = targ.transpose(0,1).contiguous().view(-1)\n",
    "    return F.nll_loss(inp.view(-1, nh), targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebca36bd11084b12afd6fcdd4281a9cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      2.606978   2.421097  \n",
      "    1      2.300751   2.210837                              \n",
      "    2      2.148818   2.094267                              \n",
      "    3      2.05272    2.018827                              \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([2.01883])]"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 4, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7b27a392bef4a778f044422b4e89daa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                             \n",
      "    0      2.003635   2.00427   \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([2.00427])]"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 1, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Identity init!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "m = CharSeqRnn(vocab_size, n_fac).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "    1     0     0  ...      0     0     0\n",
       "    0     1     0  ...      0     0     0\n",
       "    0     0     1  ...      0     0     0\n",
       "       ...          ⋱          ...       \n",
       "    0     0     0  ...      1     0     0\n",
       "    0     0     0  ...      0     1     0\n",
       "    0     0     0  ...      0     0     1\n",
       "[torch.cuda.FloatTensor of size 256x256 (GPU 0)]"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.rnn.weight_hh_l0.data.copy_(torch.eye(n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "hidden": true,
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94afbe7a5c5542339b0b34947f18c2ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      2.380318   2.210412  \n",
      "    1      2.124148   2.065752                              \n",
      "    2      2.023889   1.99549                               \n",
      "    3      1.973753   1.964885                              \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.96489])]"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 4, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "101115995bbd4b4f8e160a98ada97081",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                              \n",
      "    0      1.885918   1.898218  \n",
      "    1      1.877252   1.891765                              \n",
      "    2      1.868084   1.886029                              \n",
      "    3      1.860838   1.882068                              \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.88207])]"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 4, opt, nll_loss_seq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stateful model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\u001b[01;34mmodels\u001b[0m/  nietzsche.txt  \u001b[01;34mtrn\u001b[0m/  \u001b[01;34mval\u001b[0m/\r\n"
     ]
    }
   ],
   "source": [
    "from torchtext import vocab, data\n",
    "from fastai.nlp import *\n",
    "from fastai.lm_rnn import *\n",
    "\n",
    "PATH='data/nietzsche/'\n",
    "\n",
    "TRN_PATH = 'trn/'\n",
    "VAL_PATH = 'val/'\n",
    "TRN = f'{PATH}{TRN_PATH}'\n",
    "VAL = f'{PATH}{VAL_PATH}'\n",
    "\n",
    "%ls {PATH}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trn.txt\r\n"
     ]
    }
   ],
   "source": [
    "%ls {PATH}trn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1153, 55, 1, 590960)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TEXT = data.Field(lower=True, tokenize=list)\n",
    "bs=64\n",
    "bptt=8\n",
    "n_fac=42\n",
    "n_hidden=256\n",
    "\n",
    "FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)\n",
    "md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=3)\n",
    "\n",
    "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulRnn(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs):\n",
    "        self.vocab_size = vocab_size\n",
    "        super().__init__()\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNN(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h.size(1) != bs: self.init_hidden(bs)\n",
    "        outp,h = self.rnn(self.e(cs), self.h)\n",
    "        self.h = repackage_var(h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulRnn(md.nt, n_fac, 512).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee0991a11c20464597ba15ea7ac9d94c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.799179   1.795937  \n",
      "    1      1.635381   1.644064                                 \n",
      "    2      1.55905    1.571078                                 \n",
      "    3      1.517546   1.52435                                  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.52435])]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b2e5236459b4232994dcfb43c2341c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.439247   1.473334  \n",
      "    1      1.442974   1.466567                                 \n",
      "    2      1.437667   1.460651                                 \n",
      "    3      1.438405   1.455288                                 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.45529])]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set_lrs(opt, 1e-4)\n",
    "\n",
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### RNN loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulRnn2(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs):\n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.RNNCell(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h.size(1) != bs: self.init_hidden(bs)\n",
    "        outp = []\n",
    "        o = self.h\n",
    "        for c in cs: \n",
    "            o = self.rnn(self.e(c), o)\n",
    "            outp.append(o)\n",
    "        outp = self.l_out(torch.stack(outp))\n",
    "        self.h = repackage_var(o)\n",
    "        return F.log_softmax(outp, dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulRnn2(md.nt, n_fac, 512).cuda()\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ffec6007de5b4506bc896c3c5b23a826",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.80518    1.797154  \n",
      "    1      1.631868   1.635422                                 \n",
      "    2      1.559168   1.566188                                 \n",
      "    3      1.515277   1.51792                                  \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.51792])]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 4, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulGRU(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs):\n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.GRU(n_fac, n_hidden)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h.size(1) != bs: self.init_hidden(bs)\n",
    "        outp,h = self.rnn(self.e(cs), self.h)\n",
    "        self.h = repackage_var(h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulGRU(md.nt, n_fac, 512).cuda()\n",
    "\n",
    "opt = optim.Adam(m.parameters(), 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a0ce69b7dde44c585bbe09191d6befc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=6), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.675136   1.674389  \n",
      "    1      1.508034   1.506611                                 \n",
      "    2      1.434869   1.437619                                 \n",
      "    3      1.395137   1.395509                                 \n",
      "    4      1.361485   1.36442                                  \n",
      "    5      1.336734   1.337312                                 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.33731])]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 6, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "815045d4f21e4f30bd23160229d61ea5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.255207   1.291024  \n",
      "    1      1.252789   1.282952                                 \n",
      "    2      1.253178   1.277432                                 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.27743])]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 3, opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Putting it all together: LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import sgdr\n",
    "n_hidden = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharSeqStatefulLSTM(nn.Module):\n",
    "    def __init__(self, vocab_size, n_fac, bs, nl):\n",
    "        super().__init__()\n",
    "        self.vocab_size,self.nl = vocab_size,nl\n",
    "        self.e = nn.Embedding(vocab_size, n_fac)\n",
    "        self.rnn = nn.LSTM(n_fac, n_hidden, nl, dropout=0.5)\n",
    "        self.l_out = nn.Linear(n_hidden, vocab_size)\n",
    "        self.init_hidden(bs)\n",
    "        \n",
    "    def forward(self, cs):\n",
    "        bs = cs[0].size(0)\n",
    "        if self.h[0].size(1) != bs: self.init_hidden(bs)\n",
    "        outp,h = self.rnn(self.e(cs), self.h)\n",
    "        self.h = repackage_var(h)\n",
    "        return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n",
    "    \n",
    "    def init_hidden(self, bs):\n",
    "        self.h = (V(torch.zeros(self.nl, bs, n_hidden)),\n",
    "                  V(torch.zeros(self.nl, bs, n_hidden)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = CharSeqStatefulLSTM(md.nt, n_fac, 512, 2).cuda()\n",
    "lo = LayerOptimizer(optim.Adam, m, 1e-2, 1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(f'{PATH}models', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2bb4365fe6874457860962b86d28557b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.735124   1.677932  \n",
      "    1      1.650891   1.592743                                 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.59274])]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fit(m, md, 2, lo.opt, F.nll_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b56d0d6e15ce4edaa34deac400d60287",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=15), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch      trn_loss   val_loss                                 \n",
      "    0      1.479674   1.422725  \n",
      "    1      1.523124   1.467232                                 \n",
      "    2      1.406148   1.359548                                 \n",
      "    3      1.56533    1.495614                                 \n",
      "    4      1.487524   1.431022                                 \n",
      "    5      1.400005   1.348234                                 \n",
      "    6      1.342281   1.303726                                 \n",
      "    7      1.534022   1.470526                                 \n",
      "    8      1.499201   1.442911                                 \n",
      "    9      1.477629   1.426429                                 \n",
      "    10     1.44182    1.384572                                 \n",
      "    11     1.398508   1.342475                                 \n",
      "    12     1.352225   1.301657                                 \n",
      "    13     1.313135   1.264042                                 \n",
      "    14     1.283205   1.243538                                 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[array([1.24354])]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')\n",
    "cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]\n",
    "fit(m, md, 2**4-1, lo.opt, F.nll_loss, callbacks=cb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next(inp):\n",
    "    idxs = TEXT.numericalize(inp)\n",
    "    p = m(VV(idxs.transpose(0, 1)))\n",
    "    r = torch.multinomial(p[-1].exp(), 1)\n",
    "    return TEXT.vocab.itos[to_np(r)[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'u'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_next('for thos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_next_n(inp, n):\n",
    "    res = inp\n",
    "    for i in range(n):\n",
    "        c = get_next(inp)\n",
    "        res += c\n",
    "        inp = inp[1:] + c\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for those oneself--however, a conscience is remotesmeans (to whichmust wish love of their religious dating and though there bean infliction or that is not the validity oneso or with relations, in men for display is civilizators one of which is anothers accased also again being andorality, truthful higher intercourse ofthe jews socratic enough to the deptium--low untruthsbetrough, in notion of knowledge in\n"
     ]
    }
   ],
   "source": [
    "print(get_next_n('for thos', 400))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "nav_menu": {},
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "navigate_num": "#000000",
    "navigate_text": "#333333",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700",
    "sidebar_border": "#EEEEEE",
    "wrapper_background": "#FFFFFF"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "216px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false,
   "widenNotebook": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
